In [1]:
import tensorflow as tf
import numpy as np
In [2]:
class MyRnnCell(tf.nn.rnn_cell.RNNCell):
    """A minimal vanilla (Elman-style) RNN cell.

    Per step:
        new_state  = tanh(x @ W_xh + state @ W_hh)
        new_output = tanh(new_state @ W_ho + b_o)

    Notes:
        - Weights are created eagerly in ``__init__`` via ``tf.get_variable``,
          so constructing two cells in the same variable scope collides on the
          variable names ("W_xh", "W_hh", "W_ho", "b_o"). See the companion
          ``tf.reset_default_graph()`` call before instantiation.
        - The output passes through ``tanh`` and is therefore bounded to
          (-1, 1); if it is later fed to a softmax cross-entropy as "logits"
          (as this notebook does), that bound limits how confident the
          softmax can become.
    """

    def __init__(self, state_size, dtype, input_size=None):
        """Create the cell and its weight variables.

        Args:
            state_size: Width of the hidden state (also the output width).
            dtype: TensorFlow dtype for all variables (e.g. ``tf.float64``).
            input_size: Feature width of the inputs. Defaults to
                ``state_size``, matching the original square-``W_xh``
                formulation, but may be set independently so the cell
                accepts inputs of any width.
        """
        self._state_size = state_size
        self._dtype = dtype
        if input_size is None:
            # Backward-compatible default: input width == hidden width.
            input_size = state_size
        # Input-to-hidden projection: [input_size, state_size].
        self._W_xh = tf.get_variable(shape=[input_size, self._state_size],
            dtype=self._dtype, name="W_xh", initializer=tf.truncated_normal_initializer())
        # Hidden-to-hidden recurrence: [state_size, state_size].
        self._W_hh = tf.get_variable(shape=[self._state_size, self._state_size],
            dtype=self._dtype, name="W_hh", initializer=tf.truncated_normal_initializer())
        # Hidden-to-output projection (output width == state width here).
        self._W_ho = tf.get_variable(shape=[self._state_size, self._state_size],
            dtype=self._dtype, name="W_ho", initializer=tf.truncated_normal_initializer())
        # Output bias. (There is intentionally no hidden-state bias, to
        # preserve the original formulation.)
        self._b_o = tf.get_variable(shape=[self._state_size], dtype=self._dtype,
            name="b_o", initializer=tf.truncated_normal_initializer())

    def __call__(self, _input, state, scope=None):
        """Run one RNN step.

        Args:
            _input: ``[batch, input_size]`` input tensor for this step.
            state: ``[batch, state_size]`` previous hidden state.
            scope: Ignored; accepted for RNNCell interface compatibility.

        Returns:
            ``(new_output, new_state)`` — both ``[batch, state_size]``.
        """
        new_state = tf.tanh(tf.matmul(_input, self._W_xh)+tf.matmul(state, self._W_hh))
        new_output = tf.tanh(tf.matmul(new_state, self._W_ho)+self._b_o)
        return new_output, new_state

    @property
    def output_size(self):
        # Output width equals the hidden-state width by construction of W_ho.
        return self._state_size

    @property
    def state_size(self):
        return self._state_size
In [3]:
# Clear any previously-built graph so re-running this cell does not raise
# duplicate-variable errors: MyRnnCell.__init__ creates its variables
# eagerly via tf.get_variable with fixed names ("W_xh", "W_hh", ...).
tf.reset_default_graph()
# Hidden-state width 2 (matching the 2-feature one-hot inputs), float64 weights.
test_cell = MyRnnCell(2, tf.float64)
In [4]:
# Build a periodic toy sequence: the 3-step pattern [1,0] -> [0,1] -> [0,1],
# repeated 30 times (90 steps total of 2-dim one-hot rows).
base_pattern = np.array([[1, 0], [0, 1], [0, 1]], dtype=np.float64)
sample_seq = np.tile(base_pattern, (30, 1))
print("Sample sequence:\n{}".format(sample_seq))
# Next-step prediction task: targets are inputs shifted by one time step.
# Train on just the first 5 transitions ...
train_input = sample_seq[0:5, :]
train_output = sample_seq[1:6, :]
# ... and evaluate over every transition in the full sequence.
test_input = sample_seq[:-1, :]
test_output = sample_seq[1:, :]
In [5]:
# Placeholders for a variable-length sequence of 2-dim rows and their
# next-step targets.
inputs = tf.placeholder(shape=[None, 2], dtype=tf.float64)
targets = tf.placeholder(shape=[None, 2], dtype=tf.float64)
# dynamic_rnn expects [batch, time, features]; we train on a single sequence,
# so wrap it as batch 0. (A plain Python list suffices for a static shape —
# no need to route it through np.array.)
batch_inputs = tf.reshape(inputs, shape=[1, -1, 2])
outputs, final_state = tf.nn.dynamic_rnn(test_cell, batch_inputs, dtype=tf.float64)
# De-batch back to [time, features] so outputs align row-for-row with `targets`.
outputs = tf.reshape(outputs, shape=[-1, 2])
# Reduce the per-timestep cross-entropies to a scalar. The original passed the
# unreduced vector to minimize(), which TF silently sums — coupling the
# effective learning rate to the sequence length. NOTE(review): `outputs` are
# tanh-bounded activations used as logits here; confirm that is intended.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=targets, logits=outputs))
optimize_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
print("Training network")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(20000):
        # Only the train op matters during training; fetching `outputs` each
        # step (as the original did) just wastes a device->host copy.
        sess.run(optimize_op, feed_dict={inputs: train_input, targets: train_output})
    print("Testing network with input:\n{}".format(test_input))
    print("Expected outputs:\n{}\nNetwork activations:\n{}".format(test_output,
        sess.run(outputs, feed_dict={inputs: test_input})))
In [ ]: